# coding=utf8
import os
import re
import time

from bson import ObjectId

from libs.sele import driver_new_session
from models.market import model_market, model_market_log
from libs.dbs import model_redis, CO_STATICS_RKEY

# Pulls the numbers out of strings such as "US$2.5 Million - US$5 Million".
# Alternatives are tried left to right at each position.
re_pattern_num = (
    r'[1-9]\d*\.\d*'      # decimals with a non-zero integer part, e.g. "2.5"
    r'|0\.\d*[1-9]\d*'    # decimals below one, e.g. "0.25"
    r'|[0-9]\d*'          # plain integers, e.g. "5"
)

# Region name -> "TR" code. Presumably the value of the ``top3_markets`` search
# parameter seen in the example URL under ``__main__`` (top3_markets=TR2);
# not referenced elsewhere in this file — confirm before relying on it.
# NOTE(review): most codes are powers of two; 'TR5096' may be a typo for
# 'TR4096' — verify against the site.
tr_map = dict([
    ('Africa', 'TR128'),
    ('Central America', 'TR512'),
    ('Domestic Market', 'TR10010'),
    ('Eastern Asia', 'TR16'),
    ('Eastern Europe', 'TR8'),
    ('Mid East', 'TR64'),
    ('North America', 'TR1'),
    ('Northern Europe', 'TR1024'),
    ('Oceania', 'TR256'),
    ('South America', 'TR2'),
    ('South Asia', 'TR5096'),
    ('Southeast Asia', 'TR32'),
    ('Southern Europe', 'TR2048'),
    ('Western Europe', 'TR4'),
])


def get_list(pid, browser=None):
    """Scrape every supplier card on the current search-result page.

    For each ``div.f-icon.m-item`` card this reads the company name, its
    product thumbnails, its "Total Revenue" and its "Top 3 Markets" rows.
    It stores one document in ``model_market`` and yields that company's
    market list.

    :param pid: id grouping all documents written by one crawl run.
    :param browser: WebDriver to read from; defaults to the module-level
        ``driver`` created in ``__main__``.
    :yields: list of ``{'Market': str, 'Share': int, 'Money': float}`` dicts
        for one company. ``Money`` is the company's revenue times
        ``Share`` / 100.
    """
    if browser is None:
        browser = driver
    print(browser.current_url)
    for item in browser.find_elements_by_css_selector("div.f-icon.m-item"):
        co_name = item.find_element_by_css_selector('.title a').get_attribute('innerHTML')
        print(co_name)

        product_list = []
        for product in item.find_elements_by_css_selector('.product'):
            image = product.find_element_by_tag_name('img')
            product_list.append({
                'Link': product.find_element_by_tag_name('a').get_attribute('href'),
                'Thumb': image.get_attribute('src'),
                'Title': image.get_attribute('alt')
            })

        # Collect (market, share%) pairs first. "Money" needs the revenue, and
        # the revenue row is not guaranteed to precede the markets row, so it
        # is only computed after all rows have been read.
        shares = []
        revenue = 0
        for market_div in item.find_elements_by_css_selector('.attr'):
            market_title = market_div.find_element_by_css_selector('.name').get_attribute('innerHTML')
            if market_title == 'Top 3 Markets:':
                for market in market_div.find_elements_by_css_selector('.value span'):
                    # e.g. "North America 30%": last token is the share,
                    # the rest is the region name.
                    market_info = market.get_attribute('innerHTML').split()
                    if not market_info:
                        continue
                    shares.append((' '.join(market_info[:-1]),
                                   int(market_info[-1].replace('%', ''))))
            elif market_title == 'Total Revenue:':
                revenue_str = market_div.find_element_by_css_selector('.value span').get_attribute('innerHTML')
                numbers = re.findall(re_pattern_num, revenue_str)
                # Take the last number (upper bound of a range) and scale by
                # 1,000,000 — assumes the page states revenue in millions.
                revenue = float(numbers[-1]) * 1000000 if numbers else 0

        market_list = [
            {'Market': name, 'Share': share, 'Money': share / 100 * revenue}
            for name, share in shares
        ]
        model_market.insert_one({
            'Pid': pid,
            'CoName': co_name,
            'Products': product_list,
            'Market': market_list,
            'Revenue': revenue
        })
        yield market_list


def statics_martket(market_statics, market_list):
    """Accumulate each entry's ``Money`` into a per-market running total.

    :param market_statics: dict of market name -> total money; updated in place.
    :param market_list: iterable of dicts with ``'Market'`` and ``'Money'`` keys.
    """
    for entry in market_list:
        name = entry['Market']
        if name not in market_statics:
            # First time this market is seen: start its total from this entry.
            market_statics[name] = entry['Money']
            continue
        market_statics[name] += entry['Money']


def crawl(keyword):
    """Crawl every result page for *keyword* and write per-market share CSVs.

    Starting from the page the module-level ``driver`` currently shows, each
    page is scraped with ``get_list``. The crawl then follows the ``.next``
    button until that button is disabled or cannot be used.

    Afterwards, for every market, it writes
    ``data/{keyword}_{market}_co_statics.csv``. Each file lists every company's
    ``Money`` divided by the total ``Money`` of all companies in that market.

    :param keyword: search keyword; used only to name the output files.
    """
    pid = ObjectId()
    market_statics = {}
    while True:
        time.sleep(3)  # crude wait for the page to load before scraping
        for market_list in get_list(pid):
            statics_martket(market_statics, market_list)
        try:
            next_button = driver.find_element_by_css_selector('.next')
            if 'disable' in next_button.get_attribute('class'):
                print("没有下一页了")  # "no next page"
                break
            next_button.click()
            print("翻页")  # "turning the page"
        except Exception as err:
            # Best effort: stop paging on any lookup/click failure, but report
            # why (a bare except would also swallow KeyboardInterrupt).
            print("找翻页按钮出错= =", err)  # "error finding the next-page button"
            break

    # Crawl finished: compute each company's share of every market.
    market_share = {}
    for shop in model_market.find({'Pid': pid}):
        # Commas are stripped because the CSV below is built by hand.
        co_name = shop['CoName'].replace(',', '')
        for market in shop['Market']:
            total = market_statics.get(market['Market'], 0)
            # If every company in a market has zero (unknown) revenue, there is
            # no meaningful share; report 0 instead of dividing by zero.
            share = market['Money'] / total if total else 0.0
            market_share.setdefault(market['Market'], {})[co_name] = str(share)

    os.makedirs('data', exist_ok=True)
    for market, info in market_share.items():
        res = ['公司名称,份额']  # header: "company name,share"
        res.extend(co_name + ',' + share for co_name, share in info.items())
        path = "data/{keyword}_{market}_co_statics.csv".format(keyword=keyword, market=market)
        with open(path, 'w', encoding='utf-8-sig') as f:
            f.write("\n".join(res))


if __name__ == '__main__':
    # URLs are hard-coded for now; they are meant to be passed in later.
    url_list = [
        # Alternative search-URL form, filtered by top3_markets:
        # "http://www.alibaba.com/trade/search?IndexArea=company_en&SearchText=Aphrodisiac&c=CID100009245&Country=CN&atm=&f0=y&top3_markets=TR2"
        {
            'url': 'http://www.alibaba.com/corporations/Aphrodisiac/CID100009245--CN------------------50.html',
            'keyword': 'Aphrodisiac',
            'type': 'all'
        }
    ]

    # ``driver`` must remain a module-level global: get_list() and crawl()
    # read it directly.
    driver = driver_new_session()
    try:
        for url_item in url_list:
            driver.get(url_item['url'])
            crawl(url_item['keyword'])
    finally:
        # Always release the browser, even if a crawl fails. Assuming
        # driver_new_session() returns a Selenium WebDriver, quit() ends the
        # whole session, whereas close() only closes the current window.
        driver.quit()
